import paramiko
from io import StringIO
import os
import time
from utils.custom_log import log_start
logger = log_start('sshops')


class SSH:
    """Thin wrapper around ``paramiko.SSHClient`` for remote commands and SFTP transfers.

    Authenticates with either a password or an RSA private key. The key may be
    given as a path to a local key file or as the PEM text of the key itself.
    Usable as a context manager: ``with SSH(...) as ssh:`` connects on entry
    and closes the connection on exit.
    """

    def __init__(self, ip, port, username, password=None, key=None):
        """
        :param ip: remote host name or IP address.
        :param port: SSH port.
        :param username: login user.
        :param password: password; used only when ``key`` is not given.
        :param key: RSA private key, either a local file path or the PEM key text.
        """
        self.host = ip
        self.port = port
        self.username = username
        self.password = password
        self.key = key
        self.timeout = 60  # passed as the connect timeout (seconds)
        self.sshclient = None  # created by open()

    def open(self):
        """Create the SSH client and connect, using the key if given, else the password.

        :raises ValueError: if the connection attempt fails.
        """
        info = "> ssh {0}@{1}  -p ['{2}']".format(self.username, self.host, self.port)
        logger.info(info)
        self.sshclient = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts any unknown host key, so there is
        # no protection against a spoofed host. Kept for compatibility.
        self.sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            if self.key:
                self.connect_key(self._load_private_key())
            else:
                self.connect_pwd()
        except Exception:
            # A failed connect (e.g. bad credentials) may leave the transport
            # running, and __exit__ is not called when __enter__ raises.
            # Release the client before propagating.
            self.sshclient.close()
            raise

    def _load_private_key(self):
        """Load ``self.key`` as an RSA key, from inline PEM text or from a file path."""
        # Inline key text carries a PEM header. The original length heuristic is
        # kept as a fallback so that anything accepted before is still accepted.
        if '-----BEGIN' in self.key or len(self.key) > 255:
            return paramiko.RSAKey.from_private_key(StringIO(self.key))
        return paramiko.RSAKey.from_private_key_file(self.key)  # e.g. /root/.ssh/id_rsa

    def close(self):
        """Close the connection. Does nothing if open() never created a client.

        :raises ValueError: if closing the client fails.
        """
        client = getattr(self, 'sshclient', None)
        if client is None:
            return
        try:
            client.close()
        except Exception as e:
            error_msg = "SSH关闭失败:{}，请检查!".format(str(e))
            logger.error(error_msg)
            raise ValueError(error_msg) from e

    def _connect(self, error_template, **credentials):
        """Connect ``self.sshclient`` using only the given credentials.

        :param error_template: ``str.format`` template for the error message;
            ``{}`` receives the underlying exception.
        :returns: the connected ``self.sshclient``.
        :raises ValueError: wrapping (and chaining) any connection failure.
        """
        try:
            self.sshclient.connect(
                hostname=self.host,
                port=self.port,
                username=self.username,
                timeout=self.timeout,
                # Use only the supplied credential; never fall back to a local
                # SSH agent or keys discovered under ~/.ssh.
                allow_agent=False,
                look_for_keys=False,
                **credentials)
        except Exception as e:
            error_msg = error_template.format(e)
            logger.error(error_msg)
            raise ValueError(error_msg) from e
        return self.sshclient

    def connect_key(self, private_key):
        """Connect with a private key object (a ``paramiko.RSAKey`` as loaded by open()).

        :raises ValueError: if the connection fails.
        """
        return self._connect("SSH连接失败:{}，请检查重试!", pkey=private_key)

    def connect_pwd(self):
        """Connect with ``self.password``.

        :raises ValueError: if the connection fails.
        """
        return self._connect("SSH连接失败:{}，请检查", password=self.password)

    def command(self, command):
        """Run ``command`` on the remote host and wait for it to finish.

        Success is judged by the remote exit status (0), not by whether
        anything was written to stderr. Many tools print warnings to stderr
        while succeeding, and some fail with only a non-zero exit code.

        :returns: ``{'code': 200, 'msg': ..., 'data': stdout_text}`` on success;
            ``{'code': 500, 'msg': ...}`` on failure, where the message carries
            the stderr text, or the exit status if stderr is empty.
        """
        _stdin, stdout, stderr = self.sshclient.exec_command(command)
        # Decode as UTF-8 and replace undecodable bytes instead of raising.
        output = stdout.read().decode('utf-8', errors='replace')
        error = stderr.read().decode('utf-8', errors='replace')
        # Both streams are drained first, so waiting for the status cannot
        # block on a full output buffer.
        exit_status = stdout.channel.recv_exit_status()
        if exit_status == 0:
            return {'code': 200, 'msg': '执行命令成功!', 'data': output}
        detail = error.strip() or 'exit status {}'.format(exit_status)
        return {'code': 500, 'msg': '执行命令失败!错误:%s' % detail}

    def sftp_put_file(self, local_file, remote_file):
        """Upload ``local_file`` to ``remote_file`` on the server over SFTP.

        Errors are logged and reported in the return value, not raised.

        :returns: ``{'code': 200, 'msg': ..., 'data': local_file}`` on success,
            ``{'code': 500, 'msg': ...}`` on failure.
        """
        try:
            # The with-block closes the SFTP session, so no channel is leaked per call.
            with self.sshclient.open_sftp() as sftp:
                sftp.put(local_file, remote_file)
        except Exception as e:
            error_msg = f"上传本地文件失败:{local_file} / {remote_file} 返回：{e}"
            logger.error(error_msg)
            return {'code': 500, 'msg': f"上传文件失败!错误:{error_msg}"}
        logger.info(f"上传本地文件{local_file} 到 {self.host}:{remote_file}")
        return {'code': 200, 'msg': '上传文件成功!', 'data': f"{local_file}"}

    def sftp_get_file(self, file, local_dir, remote_file):
        """Download ``remote_file`` to ``os.path.join(local_dir, file)`` over SFTP.

        :raises ValueError: if the download fails; the original exception is
            chained and its text is included in the message.
        """
        local_path = os.path.join(local_dir, file)
        try:
            # The with-block closes the SFTP session, so no channel is leaked per call.
            with self.sshclient.open_sftp() as sftp:
                sftp.get(remote_file, local_path)
        except Exception as e:
            error_msg = '下载远端文件失败：%s 返回：%s' % (remote_file, e)
            logger.error(error_msg)
            raise ValueError(error_msg) from e
        logger.info('下载远端文件命令：%s' % remote_file)

    def test(self):
        """Connectivity check: run ``uptime`` and return the command() result dict."""
        result = self.command('uptime')
        return result

    def __enter__(self):
        """Connect at the start of a ``with`` block and bind this instance to ``as``."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection at the end of a ``with`` block.

        Returns None, so any exception raised inside the block propagates.
        """
        self.close()

if __name__ == "__main__":
    # Manual smoke test: connect with a local private-key file and print the
    # result dict of running `uptime` on the remote host.
    # Other helpers can be exercised the same way, e.g. ssh.command('ls -l')
    # or ssh.sftp_put_file(local_path, '/tmp/pagination.py').
    target = dict(ip='115.192.116.88', port=55555, username='root',
                  key="/root/.ssh/id_rsa")
    with SSH(**target) as ssh:
        print(ssh.test())